package report

import (
	"bufio"
	"bytes"
	"encoding/csv"
	"errors"
	"fmt"
	"io"
	"log"
	"strings"

	"gorm.io/gorm"
)

// TODO: implement upsert: update the row when it already exists, insert it otherwise.

// DBInsert loads CSV data into one database table by generating SQL statements.
type DBInsert struct {
	// Disabled bookkeeping counters, kept for reference:
	//errCount int
	//total    int
	//okCount  int

	// table is the destination table; its Name is spliced into generated SQL.
	table *Table
	// kv maps lower-cased column names to their column definitions (filled by NewInsert).
	kv    map[string]*Column
	//count    int
}

// Row holds the columns and values extracted from one CSV data line.
type Row struct {
	// Cols lists the names of the columns that received a value.
	Cols       []string
	// Vals holds the values as returned by Column.GetValue, parallel to Cols;
	// they are written verbatim into the generated SQL.
	Vals       []string
	// TableName is the name of the table the statement targets.
	TableName  string
	PrimaryPos int // index of the primary key within Cols/Vals; -1 when it is absent
}

// NewInsert returns a DBInsert bound to table, with the table's columns
// indexed by their lower-cased names.
func NewInsert(table *Table) *DBInsert {
	index := make(map[string]*Column)
	for _, column := range table.Columns {
		index[strings.ToLower(column.Name)] = column
	}
	return &DBInsert{table: table, kv: index}
}

// getTitles reads the CSV header line from buf and returns the column names,
// lower-cased and stripped of surrounding whitespace, in file order.
//
// The line is split on plain commas, so quoted names containing commas are not
// supported. It returns nil when no header is available (read error with no
// data, or a blank line).
func getTitles(buf *bufio.Reader) (cols []string) {
	line, err := buf.ReadString('\n')
	// ReadString returns the data read so far together with the error, so a
	// header that is not newline-terminated arrives with io.EOF and must
	// still be used.
	if err != nil && len(line) == 0 {
		return nil
	}
	// Drop a leading UTF-8 byte order mark; otherwise it would become part of
	// the first column name and that column would never match.
	line = strings.TrimSpace(strings.TrimPrefix(line, "\ufeff"))
	if line == "" {
		return nil
	}

	cols = strings.Split(line, ",")
	for i, name := range cols {
		cols[i] = strings.ToLower(strings.TrimSpace(name))
	}
	return cols
}

// getRow reads the next line from buf, parses it as one CSV record and maps
// its fields onto the header names in cols.
//
// Only the first nCols fields are considered. A field is kept when its header
// name is a known column of the table and Column.GetValue yields a non-empty
// value ("" and "''" count as empty). primary is matched case-insensitively
// against the header names; Row.PrimaryPos is the index of that column within
// Row.Cols, or -1 when the column is missing or its value is empty.
//
// It returns io.EOF once the input is exhausted. A final line without a
// trailing newline is returned as a normal row; io.EOF follows on the next
// call. A CSV syntax error is returned with a nil row. Note that a blank line
// makes the csv reader report io.EOF as well, which callers treat as end of
// input.
func (m *DBInsert) getRow(cols []string, nCols int, primary string, buf *bufio.Reader) (row *Row, err error) {
	row = &Row{TableName: m.table.Name, PrimaryPos: -1}

	line, readErr := buf.ReadString('\n')
	if len(line) == 0 {
		// Nothing was read, so readErr is non-nil (io.EOF at end of input).
		return row, readErr
	}
	// When data was read, readErr is deliberately not propagated: the line is
	// complete enough to process and the error resurfaces on the next read.

	fields, err := csv.NewReader(strings.NewReader(line)).Read()
	if err != nil {
		return nil, err
	}

	limit := len(fields)
	if nCols < limit {
		limit = nCols
	}
	if len(cols) < limit {
		limit = len(cols)
	}

	keys := make([]string, 0, len(fields))
	values := make([]string, 0, len(fields))
	for i := 0; i < limit; i++ {
		name := cols[i]
		col, ok := m.kv[name]
		if !ok {
			continue
		}
		field := strings.Trim(strings.TrimSpace(fields[i]), "\"")
		value := col.GetValue(field)
		if value == "" || value == "''" {
			// Empty values are left out of the statement.
			continue
		}
		// Record the primary key position only for a column that is really
		// kept, so the index always refers to the right entry.
		if strings.EqualFold(name, primary) {
			row.PrimaryPos = len(keys)
		}
		keys = append(keys, name)
		values = append(values, value)
	}

	row.Cols = keys
	row.Vals = values
	return row, nil
}

// makeInsertSQL renders row as an INSERT statement:
// the column names and the values are each joined with commas and written
// verbatim. The returned error is always nil.
func makeInsertSQL(row *Row) (sql string, err error) {
	var stmt bytes.Buffer
	stmt.WriteString("INSERT INTO ")
	stmt.WriteString(row.TableName)
	stmt.WriteString(" (")
	stmt.WriteString(strings.Join(row.Cols, ","))
	stmt.WriteString(") VALUES (")
	stmt.WriteString(strings.Join(row.Vals, ","))
	stmt.WriteString(")")
	return stmt.String(), nil
}

// makeUpdateSQL renders row as an UPDATE statement that sets every
// non-primary column and selects the target record by the primary key column
// (row.PrimaryPos). Names and values are written verbatim.
//
// It returns an error, and an empty statement, when the row has no primary
// key, when the primary key name or value is empty, or when there is no other
// column to update (which would otherwise produce an invalid "SET" clause).
func makeUpdateSQL(row *Row) (sql string, err error) {
	pk := row.PrimaryPos
	if pk < 0 {
		return "", errors.New("没有主键")
	}
	if pk >= len(row.Cols) || row.Cols[pk] == "" || row.Vals[pk] == "" {
		return "", errors.New("主键的条件是NULL")
	}

	assignments := make([]string, 0, len(row.Cols))
	for i, col := range row.Cols {
		if i == pk {
			continue
		}
		assignments = append(assignments, col+"="+row.Vals[i])
	}
	if len(assignments) == 0 {
		return "", errors.New("没有需要更新的列")
	}

	sql = "UPDATE " + row.TableName + " SET " + strings.Join(assignments, ", ") +
		" WHERE " + row.Cols[pk] + "=" + row.Vals[pk]
	return sql, nil
}
// getRowInsertSQL reads the next line from buf, parses it as one CSV record
// and renders it as an INSERT statement for the table.
//
// Only the first nCols fields are considered, and only those whose header
// name in cols is a known column; each value is converted with
// Column.GetValue and written verbatim. Unlike getRow, empty values are kept.
//
// It returns io.EOF once the input is exhausted. A final line without a
// trailing newline yields a normal statement; io.EOF follows on the next
// call. A CSV syntax error is returned with an empty statement. Note that a
// blank line makes the csv reader report io.EOF as well, which callers treat
// as end of input.
func (m *DBInsert) getRowInsertSQL(cols []string, nCols int, buf *bufio.Reader) (sql string, err error) {
	line, readErr := buf.ReadString('\n')
	if len(line) == 0 {
		// Nothing was read, so readErr is non-nil (io.EOF at end of input).
		return "", readErr
	}
	// When data was read, readErr is deliberately not propagated: the line is
	// processed and the error resurfaces on the next read.

	fields, err := csv.NewReader(strings.NewReader(line)).Read()
	if err != nil {
		return "", err
	}

	limit := len(fields)
	if nCols < limit {
		limit = nCols
	}
	if len(cols) < limit {
		limit = len(cols)
	}

	keys := make([]string, 0, len(fields))
	values := make([]string, 0, len(fields))
	for i := 0; i < limit; i++ {
		col, ok := m.kv[cols[i]]
		if !ok {
			continue
		}
		field := strings.Trim(strings.TrimSpace(fields[i]), "\"")
		keys = append(keys, cols[i])
		values = append(values, col.GetValue(field))
	}

	sql = "INSERT INTO " + m.table.Name + " (" + strings.Join(keys, ",") +
		") VALUES (" + strings.Join(values, ",") + ")"
	return sql, nil
}

//InsertRow 插入数据
// func (m *DBInsert) InsertRow(sql string) (err error) {
// 	if m.errCount > 20 {
// 		return
// 	}

// 	_, err = m.sess.Exec(sql)
// 	if err != nil {
// 		return
// 	}
// 	m.total++
// 	m.okCount++
// 	if m.okCount >= 2000 {
// 		log.Println("sess commit")

// 		err = m.sess.Commit()
// 		m.okCount = 0
// 		m.bInit = false
// 	}
// 	return
// }

// WriteError appends one record of the form "Line: <lineNO> <message>\r\n" to
// wirte. It writes nothing when err is nil.
func WriteError(wirte io.StringWriter, lineNO int, err error) {
	if err != nil {
		prefix := fmt.Sprintf("Line: %d ", lineNO)
		for _, piece := range []string{prefix, err.Error(), "\r\n"} {
			wirte.WriteString(piece)
		}
	}
}

// Insert reads CSV data from reader and inserts every data line into the
// table.
//
// The first line is the header; double quotes around header names are
// stripped. Rows are executed inside a transaction that is committed after
// every 1000 successful inserts and once more at the end.
//
// Failures do not abort the import: they are collected as
// "Line: <n> <message>" records (row-level failures are capped at the first
// 19) and returned together as a single error. The result is nil when nothing
// failed. An error is returned immediately when the header is missing or the
// first transaction cannot be started.
func (m *DBInsert) Insert(db *gorm.DB, reader io.Reader) (err error) {
	buf := bufio.NewReader(reader)
	cols := getTitles(buf)
	nCols := len(cols)
	if nCols == 0 {
		return errors.New("没有数据")
	}
	for i := range cols {
		cols[i] = strings.Trim(cols[i], "\"")
	}

	sess := db.Begin()
	if sess.Error != nil {
		return sess.Error
	}

	errBuf := bytes.NewBuffer(nil)
	lineNO := 0
	okCount := 0
	errCount := 0
	// report records a row-level failure, keeping the error text bounded.
	report := func(e error) {
		errCount++
		if errCount < 20 {
			WriteError(errBuf, lineNO, e)
		}
	}
	defer func() {
		// sess is nil only when re-opening a transaction failed; there is
		// nothing left to commit in that case.
		if sess != nil {
			WriteError(errBuf, lineNO, sess.Commit().Error)
		}
		if errBuf.Len() > 0 {
			err = errors.New(errBuf.String())
		}
	}()

	for {
		lineNO++
		sql, rowErr := m.getRowInsertSQL(cols, nCols, buf)
		if rowErr == io.EOF {
			return nil
		}
		if rowErr != nil {
			report(rowErr)
			var parseErr *csv.ParseError
			if errors.As(rowErr, &parseErr) {
				// The malformed line has been consumed; go on with the next.
				continue
			}
			// Any other error comes from the reader itself and may repeat
			// forever, so stop instead of spinning.
			return nil
		}

		if execErr := sess.Exec(sql).Error; execErr != nil {
			report(execErr)
			continue
		}

		okCount++
		if okCount%1000 != 0 {
			continue
		}
		// Commit the finished batch and open a transaction for the next one.
		WriteError(errBuf, lineNO, sess.Commit().Error)
		sess = db.Begin()
		if sess.Error != nil {
			WriteError(errBuf, lineNO, sess.Error)
			sess = nil
			return nil
		}
	}
}

// InsertOrUpdate reads CSV data from reader and, for every data line, inserts
// the row; when the insert fails it updates the existing row selected by the
// primary column instead.
//
// The first line is the header; double quotes around header names are
// stripped. Statements run inside a transaction that is committed after every
// 1000 rows written (inserted or updated) and once more at the end.
//
// Failures do not abort the import: they are collected as
// "Line: <n> <message>" records (statement failures are capped at the first
// 19) and returned together as a single error. The result is nil when nothing
// failed. An error is returned immediately when the header is missing or the
// first transaction cannot be started.
//
// NOTE(review): the fallback UPDATE runs in the same transaction as the failed
// INSERT. On databases that abort a transaction after a failed statement
// (PostgreSQL does) the UPDATE would fail too; confirm the target database.
func (m *DBInsert) InsertOrUpdate(db *gorm.DB, primary string, reader io.Reader) (err error) {
	buf := bufio.NewReader(reader)
	cols := getTitles(buf)
	nCols := len(cols)
	if nCols == 0 {
		return errors.New("没有数据")
	}
	for i := range cols {
		cols[i] = strings.Trim(cols[i], "\"")
	}

	sess := db.Begin()
	if sess.Error != nil {
		return sess.Error
	}

	errBuf := bytes.NewBuffer(nil)
	lineNO := 0
	okCount := 0
	errCount := 0
	defer func() {
		// sess is nil only when re-opening a transaction failed; there is
		// nothing left to commit in that case.
		if sess != nil {
			WriteError(errBuf, lineNO, sess.Commit().Error)
		}
		if errBuf.Len() > 0 {
			err = errors.New(errBuf.String())
		}
	}()

	for {
		lineNO++
		row, rowErr := m.getRow(cols, nCols, primary, buf)
		if rowErr == io.EOF {
			return nil
		}
		if rowErr != nil {
			WriteError(errBuf, lineNO, rowErr)
			var parseErr *csv.ParseError
			if errors.As(rowErr, &parseErr) {
				// The malformed line has been consumed; go on with the next.
				continue
			}
			// Any other error comes from the reader itself and may repeat
			// forever, so stop instead of spinning.
			return nil
		}

		sql, sqlErr := makeInsertSQL(row)
		if sqlErr != nil {
			errCount++
			if errCount < 20 {
				WriteError(errBuf, lineNO, sqlErr)
			}
			continue
		}

		if execErr := sess.Exec(sql).Error; execErr != nil {
			// The insert was rejected; fall back to updating the row.
			log.Println(sql)
			sql, sqlErr = makeUpdateSQL(row)
			if sqlErr != nil {
				WriteError(errBuf, lineNO, sqlErr)
				log.Println(sql)
				continue
			}
			if execErr = sess.Exec(sql).Error; execErr != nil {
				// Count only real failures, so that successful updates do
				// not use up the reporting budget.
				errCount++
				if errCount < 20 {
					log.Println(sql)
					WriteError(errBuf, lineNO, execErr)
				}
				continue
			}
		}

		okCount++
		if okCount%1000 != 0 {
			continue
		}
		// Commit the finished batch and open a transaction for the next one.
		WriteError(errBuf, lineNO, sess.Commit().Error)
		sess = db.Begin()
		if sess.Error != nil {
			WriteError(errBuf, lineNO, sess.Error)
			sess = nil
			return nil
		}
	}
}
